package secrets

import (
	"bytes"
	"fmt"
	"log"
	"os"

	"apiote.xyz/p/eeze/crypto"
	"apiote.xyz/p/eeze/fido"
	"apiote.xyz/p/eeze/fs"

	"git.sr.ht/~sircmpwn/go-bare"
	"golang.org/x/crypto/argon2"
	"golang.org/x/crypto/ssh/terminal"
)

// PasswordCredential unlocks the master key with a password typed at the
// terminal. EncryptedPassword holds the master key encrypted under a key that
// is derived from the password and Salt with Argon2id (see addPasswordCred
// and openCredential).
//
// Instances are serialised with go-bare, which encodes struct fields in
// declaration order: do not reorder, remove or insert fields without
// migrating existing credential files. CredenialName is misspelled but is
// part of the exported API, so it is kept as-is.
type PasswordCredential struct {
	CredenialName     string
	EncryptedPassword []byte
	Salt              []byte
}

// IsUnion marks PasswordCredential as a variant of the Credential union.
func (PasswordCredential) IsUnion() {}

// Name returns the label the user gave this credential.
func (c PasswordCredential) Name() string { return c.CredenialName }

// String formats the credential for listings as "<name>: Password".
func (c PasswordCredential) String() string {
	const kind = "Password"
	return c.Name() + ": " + kind
}

// FidoCredential unlocks the master key with a secret obtained from a FIDO2
// authenticator via fido.GetHmacSecret.
//
//   - AAGUID is copied from the authenticator's info at enrolment and is how
//     plugged-in devices are matched to this credential.
//   - CredentialID is the value returned by fido.Setup at enrolment.
//   - ClientDataHash and Salt are random values passed to fido.Setup /
//     fido.GetHmacSecret; Salt is also the Argon2id salt used to turn the
//     returned secret into the key that encrypts the master key.
//   - EncryptedPassword is the master key encrypted with that derived key.
//
// Instances are serialised with go-bare, which encodes struct fields in
// declaration order: do not reorder, remove or insert fields without
// migrating existing credential files.
type FidoCredential struct {
	AAGUID            []byte
	CredenialName     string
	Salt              []byte
	ClientDataHash    []byte
	CredentialID      []byte
	EncryptedPassword []byte
}

// IsUnion marks FidoCredential as a variant of the Credential union.
func (FidoCredential) IsUnion() {}

// Name returns the label the user gave this credential.
func (c FidoCredential) Name() string { return c.CredenialName }

// String formats the credential for listings as "<name>: FIDO2 key".
func (c FidoCredential) String() string {
	const kind = "FIDO2 key"
	return c.Name() + ": " + kind
}

// Credential is one stored way of unlocking the master key. It is a go-bare
// union so that the credentials file can hold a mix of kinds; in this file
// it is implemented by PasswordCredential and FidoCredential.
type Credential interface {
	// Name returns the user-chosen label used to list and delete credentials.
	Name() string
	bare.Union
}

// Key is the unlocked master key returned by Open. bytes holds the plaintext
// key recovered by decrypting one of the stored credentials; the field is
// unexported so only this package can read it.
type Key struct {
	bytes []byte
}

// readCredentials loads credentials.bare from the data home directory and
// decodes it as a BARE list of Credential values.
//
// On failure it still returns the (empty or partially decoded) slice along
// with the wrapped error; callers should only use the slice when err is nil.
func readCredentials() ([]Credential, error) {
	result := []Credential{}
	path := fs.DataHome() + "/credentials.bare"

	raw, err := os.ReadFile(path)
	if err != nil {
		return result, fmt.Errorf("while reading credentials: %w", err)
	}
	if err := bare.Unmarshal(raw, &result); err != nil {
		return result, fmt.Errorf("while unmarshalling credentials: %w", err)
	}
	return result, nil
}

// writeCredentials encodes credentials with BARE and replaces
// credentials.bare in the data home directory.
//
// The file holds the encrypted master key, so it is replaced atomically. The
// data is written to a temporary file in the same directory, synced to
// stable storage, and renamed over the old file. A crash or a full disk
// therefore leaves either the old or the new list, never a truncated one.
// os.CreateTemp creates the temporary file with mode 0600, so the result is
// readable only by its owner.
func writeCredentials(credentials []Credential) error {
	dataHome := fs.DataHome()
	credentialsBytes, err := bare.Marshal(&credentials)
	if err != nil {
		return fmt.Errorf("while marshalling credentials: %w", err)
	}

	tmp, err := os.CreateTemp(dataHome, "credentials.bare.*")
	if err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	committed := false
	defer func() {
		if !committed {
			// Best effort cleanup; the original error is what matters.
			tmp.Close()
			os.Remove(tmp.Name())
		}
	}()

	if _, err = tmp.Write(credentialsBytes); err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	if err = tmp.Sync(); err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	if err = tmp.Close(); err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	if err = os.Rename(tmp.Name(), dataHome+"/credentials.bare"); err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	committed = true
	return nil
}

// Open reads the stored credentials and uses one of them to unlock the
// master key. The user may be prompted to touch a FIDO2 key or type a
// password (see openCredential).
//
// Errors are wrapped with %w so callers can inspect them with errors.Is and
// errors.As. On error the zero Key is returned.
func Open() (Key, error) {
	credentials, err := readCredentials()
	if err != nil {
		return Key{}, fmt.Errorf("while getting credentials: %w", err)
	}
	keyBytes, err := openCredential(credentials)
	if err != nil {
		return Key{}, fmt.Errorf("while getting key: %w", err)
	}
	return Key{bytes: keyBytes}, nil
}

// InitialiseCredentials prepares the data home. It creates the data home
// directory and its secrets subdirectory, and writes an empty
// credentials.bare if none exists yet.
//
// An existing credentials file is never overwritten, so calling this again
// is harmless. If writing the new file fails, the partial file is removed.
// Otherwise later calls would treat it as existing and never repair it.
func InitialiseCredentials() error {
	dataHome := fs.DataHome()
	if err := os.MkdirAll(dataHome, 0755); err != nil {
		return fmt.Errorf("while making home directory: %w", err)
	}
	if err := os.MkdirAll(dataHome+"/secrets", 0755); err != nil {
		return fmt.Errorf("while making secrets directory: %w", err)
	}

	credentials := []Credential{}
	credentialsBytes, err := bare.Marshal(&credentials)
	if err != nil {
		return fmt.Errorf("while marshalling credentials: %w", err)
	}

	path := dataHome + "/credentials.bare"
	// O_EXCL makes creation fail if the file already exists, so stored
	// credentials are never clobbered. Mode 0600 keeps the encrypted key
	// material readable only by its owner.
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		// os.IsExist is portable, unlike matching the error text, and does
		// not panic on unexpected error types.
		if os.IsExist(err) {
			return nil
		}
		return fmt.Errorf("while creating credentials file: %w", err)
	}
	if _, err = file.Write(credentialsBytes); err != nil {
		file.Close()
		os.Remove(path)
		return fmt.Errorf("while writing credentials: %w", err)
	}
	if err = file.Close(); err != nil {
		os.Remove(path)
		return fmt.Errorf("while writing credentials: %w", err)
	}
	return nil
}

// AddCredential enrols a new credential called name and stores it alongside
// the existing ones.
//
// If no credentials exist yet, a fresh master key is generated. Otherwise
// the existing key is first unlocked with one of the stored credentials, so
// that the new credential wraps the same key. typeFido selects a FIDO2
// credential and takes precedence over typePass, which selects a password
// credential. At least one of them must be set.
//
// Arguments are validated before the key is unlocked, so the user is not
// asked for a password or key touch only to have the request rejected.
// Names must be unique because DelCredential removes credentials by name.
func AddCredential(typeFido, typePass bool, name string) error {
	if !typeFido && !typePass {
		return fmt.Errorf("while adding credential: neither type passed")
	}
	credentials, err := readCredentials()
	if err != nil {
		return fmt.Errorf("while reading credentials: %w", err)
	}
	for _, existing := range credentials {
		if existing.Name() == name {
			return fmt.Errorf("while adding credential: credential %q already exists", name)
		}
	}

	var key []byte
	if len(credentials) == 0 {
		key = crypto.MakeKey()
	} else {
		key, err = openCredential(credentials)
		if err != nil {
			return fmt.Errorf("while getting key: %w", err)
		}
	}

	var credential Credential
	if typeFido {
		credential, err = addFidoCred(key, credentials, name)
	} else {
		credential, err = addPasswordCred(key, name)
	}
	if err != nil {
		return fmt.Errorf("while adding credential: %w", err)
	}

	credentials = append(credentials, credential)
	if err = writeCredentials(credentials); err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	return nil
}

// DelCredential removes every stored credential called name.
//
// It returns an error if no credential has that name, instead of silently
// doing nothing. It also refuses to remove the last remaining credential:
// without any credential, openCredential can no longer unlock the master
// key, and the next AddCredential would generate a brand-new key instead.
func DelCredential(name string) error {
	credentials, err := readCredentials()
	if err != nil {
		return fmt.Errorf("while reading credentials: %w", err)
	}
	if len(credentials) == 0 {
		return fmt.Errorf("no credentials")
	}

	remaining := make([]Credential, 0, len(credentials))
	for _, credential := range credentials {
		if credential.Name() != name {
			remaining = append(remaining, credential)
		}
	}
	if len(remaining) == len(credentials) {
		return fmt.Errorf("no credential named %q", name)
	}
	if len(remaining) == 0 {
		return fmt.Errorf("refusing to delete %q: it is the last credential and the key would become unrecoverable", name)
	}

	if err = writeCredentials(remaining); err != nil {
		return fmt.Errorf("while writing credentials: %w", err)
	}
	return nil
}

// ListCredentials prints each stored credential on its own line to stdout,
// using the credential's String form. It fails if the credentials file
// cannot be read or contains no credentials.
func ListCredentials() error {
	stored, err := readCredentials()
	switch {
	case err != nil:
		return fmt.Errorf("while reading credentials: %w", err)
	case len(stored) == 0:
		return fmt.Errorf("no credentials")
	}
	for i := range stored {
		fmt.Println(stored[i])
	}
	return nil
}

// addPasswordCred asks on the terminal for a new password, twice for
// confirmation. It returns a PasswordCredential called name that wraps key.
// The password is stretched with Argon2id using a fresh random salt, and the
// result is used to encrypt key.
//
// Empty passwords are rejected because openCredential treats an empty entry
// as "skip this credential", so such a credential could never unlock the key.
// On error the zero PasswordCredential is returned.
func addPasswordCred(key []byte, name string) (Credential, error) {
	credential := PasswordCredential{
		Salt:          crypto.MakeKey(),
		CredenialName: name,
	}

	fmt.Print("Type the password: ")
	password1, err := terminal.ReadPassword(int(os.Stdin.Fd()))
	// Input is not echoed, so end the prompt line ourselves, even on error.
	fmt.Print("\n")
	if err != nil {
		return PasswordCredential{}, fmt.Errorf("while reading password from stdin: %w", err)
	}
	if len(password1) == 0 {
		return PasswordCredential{}, fmt.Errorf("password must not be empty")
	}

	fmt.Print("Confirm the password: ")
	password2, err := terminal.ReadPassword(int(os.Stdin.Fd()))
	fmt.Print("\n")
	if err != nil {
		return PasswordCredential{}, fmt.Errorf("while reading password from stdin: %w", err)
	}
	if !bytes.Equal(password1, password2) {
		return PasswordCredential{}, fmt.Errorf("passwords do not match")
	}

	argonedSecret := argon2.IDKey(password1, credential.Salt, 1, 64*1024, 4, 32)
	credential.EncryptedPassword, err = crypto.Encrypt(key, argonedSecret)
	if err != nil {
		return PasswordCredential{}, fmt.Errorf("while encrypting key: %w", err)
	}
	return credential, nil
}

// addFidoCred enrols the first plugged-in FIDO2 device that
// filterKnownDevices does not already associate with one of credentials. It
// returns a FidoCredential called name that wraps key with a secret
// obtained from that device.
//
// The user is asked to touch the device twice. The first touch is for
// fido.Setup, which yields the credential ID. The second is for
// fido.GetHmacSecret, which yields the secret that, after Argon2id with
// Salt, encrypts key. On error the zero FidoCredential is returned.
func addFidoCred(key []byte, credentials []Credential, name string) (Credential, error) {
	credential := FidoCredential{
		Salt:           crypto.MakeKey(),
		ClientDataHash: crypto.MakeKey(),
		CredenialName:  name,
	}

	devices, err := fido.GetPluggedDevices()
	if err != nil {
		return FidoCredential{}, fmt.Errorf("while getting devices: %w", err)
	}
	unknownDevices := filterKnownDevices(devices, credentials)
	if len(unknownDevices) == 0 {
		return FidoCredential{}, fmt.Errorf("no unknown devices plugged in")
	}
	device := unknownDevices[0]
	credential.AAGUID = device.Info.AAGUID

	fmt.Println("Touch the key to prime it")
	credential.CredentialID, err = fido.Setup("eeze", "", credential.ClientDataHash, device)
	if err != nil {
		return FidoCredential{}, fmt.Errorf("while priming key: %w", err)
	}

	fmt.Println("Touch the key to create passphrase")
	secret, err := fido.GetHmacSecret("eeze", "", credential.ClientDataHash,
		credential.CredentialID, credential.Salt, device)
	if err != nil {
		return FidoCredential{}, fmt.Errorf("while getting HMAC secret: %w", err)
	}

	argonedSecret := argon2.IDKey(secret, credential.Salt, 1, 64*1024, 4, 32)
	credential.EncryptedPassword, err = crypto.Encrypt(key, argonedSecret)
	if err != nil {
		return FidoCredential{}, fmt.Errorf("while encrypting key: %w", err)
	}
	return credential, nil
}

// filterKnownDevices returns, in their original order, the devices whose
// AAGUID does not match that of any FIDO credential in credentials.
//
// FIDO credentials are recognised both as *FidoCredential (as decoded from
// the credentials file) and as FidoCredential values (as built by
// addFidoCred), so the result does not depend on which form a caller holds.
//
// NOTE(review): an AAGUID identifies an authenticator model rather than an
// individual key. A second key of the same model is therefore treated as
// already known and cannot be enrolled. Confirm this is intended.
func filterKnownDevices(devices []fido.Device, credentials []Credential) []fido.Device {
	unknownDevices := []fido.Device{}
deviceLoop:
	for _, device := range devices {
		for _, credential := range credentials {
			var aaguid []byte
			switch c := credential.(type) {
			case *FidoCredential:
				aaguid = c.AAGUID
			case FidoCredential:
				aaguid = c.AAGUID
			default:
				// Not a FIDO credential; never matches a device.
				continue
			}
			if bytes.Equal(device.Info.AAGUID, aaguid) {
				continue deviceLoop
			}
		}
		unknownDevices = append(unknownDevices, device)
	}
	return unknownDevices
}

// openCredential tries the stored credentials until one yields the master
// key, and returns that key.
//
// FIDO2 credentials are tried first. For each plugged-in device, every
// stored FIDO credential with the same AAGUID is attempted: its name is
// printed to stderr as a prompt and the HMAC secret is requested from the
// device. Failures to get the secret or decrypt the key are logged as
// warnings and the next candidate is tried. If devices cannot be enumerated,
// only password credentials are tried. A warning is logged unless the binary
// was built without FIDO2 support.
//
// Password credentials are then prompted for in order on stderr. An empty
// entry skips to the next credential. A wrong password (decryption failure)
// or a terminal read error aborts with an error.
//
// Credentials are accepted both as pointers (as decoded from the credentials
// file) and as values (as built by addFidoCred/addPasswordCred).
func openCredential(credentials []Credential) ([]byte, error) {
	devices, err := fido.GetPluggedDevices()
	if err != nil {
		devices = []fido.Device{}
		// A build without FIDO2 support is an expected configuration, not a
		// problem worth warning about.
		if err.Error() != "Compiled without FIDO2 support" {
			log.Println("WARN while getting plugged devices:", err)
		}
	}

	for _, device := range devices {
		for _, credential := range credentials {
			var c *FidoCredential
			switch v := credential.(type) {
			case *FidoCredential:
				c = v
			case FidoCredential:
				c = &v
			default:
				continue
			}
			if !bytes.Equal(device.Info.AAGUID, c.AAGUID) {
				continue
			}
			fmt.Fprintf(os.Stderr, "[%s]: ", c.Name())
			secret, err := fido.GetHmacSecret("eeze", "", c.ClientDataHash,
				c.CredentialID, c.Salt, device)
			fmt.Fprintln(os.Stderr, "")
			if err != nil {
				log.Println("WARN while getting HMAC secret:", err)
				continue
			}
			argonedSecret := argon2.IDKey(secret, c.Salt, 1, 64*1024, 4, 32)
			key, err := crypto.Decrypt(c.EncryptedPassword, argonedSecret)
			if err != nil {
				log.Println("WARN while decrypting key:", err)
				continue
			}
			return key, nil
		}
	}

	for _, credential := range credentials {
		var c *PasswordCredential
		switch v := credential.(type) {
		case *PasswordCredential:
			c = v
		case PasswordCredential:
			c = &v
		default:
			continue
		}
		fmt.Fprintf(os.Stderr, "[%s]: ", c.Name())
		password, err := terminal.ReadPassword(int(os.Stdin.Fd()))
		// Input is not echoed; terminate the prompt line even if reading failed.
		fmt.Fprint(os.Stderr, "\n")
		if err != nil {
			return []byte{}, fmt.Errorf("while reading password from stdin: %w", err)
		}
		if len(password) == 0 {
			continue
		}
		argonedSecret := argon2.IDKey(password, c.Salt, 1, 64*1024, 4, 32)
		key, err := crypto.Decrypt(c.EncryptedPassword, argonedSecret)
		if err != nil {
			return []byte{}, fmt.Errorf("while decrypting key: %w", err)
		}
		return key, nil
	}
	return []byte{}, fmt.Errorf("No credential opened")
}
